/*    Copyright 2010 Tobias Marschall
 *
 *    This file is part of MoSDi.
 *
 *    MoSDi is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    MoSDi is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoSDi.  If not, see <http://www.gnu.org/licenses/>.
 */

package mosdi.tests;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import junit.framework.TestCase;
import mosdi.fa.Alphabet;
import mosdi.index.SuffixTree;
import mosdi.index.SuffixTreeWalker;
import mosdi.matching.HorspoolMatcher;
import mosdi.util.BitArray;
import mosdi.util.Iupac;
import mosdi.util.NamedSequence;
import mosdi.util.SequenceUtils;
import mosdi.util.iterators.LexicographicalIterator;
import mosdi.util.iterators.StringIterator;

public class SuffixTreeTest extends TestCase {
	
	public static final String text = "TAACTAAGATTACCAGCCATGGTGGAGTGGATTAATAGCAAAGAAGTCGAGCCCCGAATAAATAAAGGGAGAAAGGGCTTATCCGCGCTGATGCGCGCGG";
	
	public void testSuffixTree() {
		String s = "mississippi";
		Alphabet alphabet = new Alphabet(s);
		SuffixTree st = new SuffixTree();
//		st.addNodeConstructionListener(new SuffixTree.NodeConstructionListener() {
//			public void nodeConstructed(int[] suffixStartPositions, int lcp, int parentLcp) {
//				// System.out.println("New node: "+Arrays.toString(suffixStartPositions)+", lcp "+lcp);
//				System.out.println("New node: "+"mississippi$".substring(suffixStartPositions[0], suffixStartPositions[0]+lcp)+" ("+"mississippi$".substring(suffixStartPositions[0]+parentLcp, suffixStartPositions[0]+lcp)+")");
//			}
//		});
		st.buildTree(alphabet.buildIndexArray(s,true), alphabet.size());
		assertEquals(2, st.countMatches(alphabet.buildIndexArray("ssi")));
		assertEquals(1, st.countMatches(alphabet.buildIndexArray("mi")));
		assertEquals(0, st.countMatches(alphabet.buildIndexArray("psimi")));
		assertEquals(2, st.countMatches(alphabet.buildIndexArray("p")));
		assertEquals(1, st.countMatches(alphabet.buildIndexArray("ppi")));
		assertEquals(4, st.countMatches(alphabet.buildIndexArray("i")));
	}

	public void testSuffixTree2() {
		Alphabet dnaAlphabet = Alphabet.getDnaAlphabet();
		SuffixTree st = new SuffixTree(dnaAlphabet.buildIndexArray(text,true), dnaAlphabet.size());
		assertEquals(30, st.countMatches(dnaAlphabet.buildIndexArray("G")));
		assertEquals(3, st.countMatches(dnaAlphabet.buildIndexArray("AGA")));
		assertEquals(1, st.countMatches(dnaAlphabet.buildIndexArray(text)));
		assertEquals(4, st.countMatches(Iupac.toGeneralizedString("WCNV")));
		assertEquals(6, st.countMatches(Iupac.toGeneralizedString("KMNNNNSW")));
		assertEquals(text.length()-8+1, st.countMatches(Iupac.toGeneralizedString("NNNNNNNN")));
		SuffixTree st2 = new SuffixTree(dnaAlphabet.buildIndexArray(text,true), dnaAlphabet.size(), 4);
		assertTrue(st.getNodeCount()>st2.getNodeCount());
		assertEquals(30, st2.countMatches(dnaAlphabet.buildIndexArray("G")));
		assertEquals(3, st2.countMatches(dnaAlphabet.buildIndexArray("AGA")));
		assertEquals(1, st2.countMatches(dnaAlphabet.buildIndexArray(text)));
		assertEquals(4, st2.countMatches(Iupac.toGeneralizedString("WCNV")));
		assertEquals(text.length(), st2.countMatches(Iupac.toGeneralizedString("N")));
		assertEquals(text.length()-2+1, st2.countMatches(Iupac.toGeneralizedString("NN")));
		assertEquals(text.length()-3+1, st2.countMatches(Iupac.toGeneralizedString("NNN")));
		assertEquals(text.length()-4+1, st2.countMatches(Iupac.toGeneralizedString("NNNN")));
	}

	public void testSuffixAnnotations() {
		Alphabet dnaAlphabet = Alphabet.getDnaAlphabet();
		SuffixTree st = new SuffixTree(dnaAlphabet.buildIndexArray(text,true), dnaAlphabet.size());
		int[] nodes = st.walk(Iupac.toGeneralizedString("NNNN"));
		int[] maxMatchCountAnnotation = st.calcMaxMatchCountAnnotation(4);
		assertEquals(st.getNodeCount(), maxMatchCountAnnotation.length);
		int[] occurrenceCountAnnotation = st.calcOccurrenceCountAnnotation(); 
		assertEquals(st.getNodeCount(), occurrenceCountAnnotation.length);
		SuffixTree st2 = new SuffixTree(dnaAlphabet.buildIndexArray(text,true), dnaAlphabet.size(),4);
		int[] maxMatchCountAnnotation2 = st2.calcMaxMatchCountAnnotation(4);
		assertEquals(st2.getNodeCount(), maxMatchCountAnnotation2.length);
		int[] occurrenceCountAnnotation2 = st2.calcOccurrenceCountAnnotation();
		assertEquals(st2.getNodeCount(), occurrenceCountAnnotation2.length);
		int[] nodes2 = st2.walk(Iupac.toGeneralizedString("NNNN"));
		assertEquals(nodes.length, nodes2.length);
		for (int i=0; i<nodes.length; ++i) {
			assertEquals(maxMatchCountAnnotation[nodes[i]],maxMatchCountAnnotation2[nodes2[i]]);
			assertEquals(occurrenceCountAnnotation[nodes[i]],occurrenceCountAnnotation2[nodes2[i]]);
		}
	}
	
	private BitArray findMatches(Alphabet alphabet, List<NamedSequence> l, int[] s, boolean considerReverse) {
		BitArray result = new BitArray(l.size());
		HorspoolMatcher matcher = new HorspoolMatcher(alphabet.size(), s);
		Alphabet iupacAlphabet = Alphabet.getIupacAlphabet();
		int[] revIupac = Iupac.reverseComplementary(iupacAlphabet.buildIndexArray(alphabet.buildString(s)));
		int[] rev = alphabet.buildIndexArray(iupacAlphabet.buildString(revIupac));
		HorspoolMatcher revMatcher = new HorspoolMatcher(alphabet.size(), rev);
		for (int i=0; i<l.size(); ++i) {
			result.set(i, matcher.findMatches(l.get(i).getSequence())>0);
		}
		if (considerReverse) {
			for (int i=0; i<l.size(); ++i) {
				if (revMatcher.findMatches(l.get(i).getSequence())>0) result.set(i, 1);
			}
		}
		return result;
	}
	
	private void assertSequenceCountAnnotation(Alphabet alphabet, SuffixTree st, List<NamedSequence> l, BitArray[] annotation, int maxDepth, boolean considerReverse) {
		int maxLength = 0;
		for (NamedSequence ns : l) maxLength = Math.max(maxLength, ns.length());
		SuffixTreeWalker walker = st.getWalker();
		LexicographicalIterator iterator = new StringIterator(alphabet.size(), maxLength);
		while (iterator.hasNext()) {
			int[] s = iterator.next();
			int i = iterator.getLeftmostChangedPosition();
			walker.backward(i);
			for (; i<s.length; ++i) {
				int[] nodes = walker.forward(s[i]);
				assertTrue(nodes.length<2);
				int[] nodeLabel = Arrays.copyOfRange(s, 0, i+1);
				BitArray matches = findMatches(alphabet, l, nodeLabel,considerReverse);
				if (nodes.length==0) {
					if (nodeLabel.length<=maxDepth) {
						assertTrue(matches.allZero());
					}
					iterator.skip(i);
					break;
				} else {
					assertTrue(matches.equals(annotation[nodes[0]]));
				}
			}
		}
	}
	
	public void testSequenceOccurrenceAnnotations() {
		Alphabet dnaAlphabet = Alphabet.getDnaAlphabet();
		List<NamedSequence> l = new ArrayList<NamedSequence>();
		l.add(new NamedSequence("seq1", dnaAlphabet.buildIndexArray("ATTCACAAGCACAACGCATAAAAGGACGACCTGGCCTGCCAAGTGCAACGGCGAAGTTTTCGAACGTCGGTGCGGGGCCGTGTTGCCCGACTCATCATCA")));
		l.add(new NamedSequence("seq2", dnaAlphabet.buildIndexArray("GCTGGAACGACGCCCGCGGTCGTTTCACCGCAGGGGCGGCCATAGGATGTCAAGCCGGACACGATGTTTGCCCCGTAAAAGGATCCGACCGGGCCGAAGG")));
		l.add(new NamedSequence("seq3", dnaAlphabet.buildIndexArray("CATAAGAGGAAACGAGGTGGCCCGACGCCCTCGACGCGTCGCGCCGTCGTAAGAGGACCTCATCGCCGGTGGAAGTGCGCACAGACCTCCCGCGACAGTC")));
		l.add(new NamedSequence("seq4", dnaAlphabet.buildIndexArray("AACGCCCGGGCGGTCCGGGTGACTGTGGTGGTGAGGTCGTGACAGGTCGGCGGCTCCCGTTGAGCGCGCGTGGAGGTCACTCGCTGGCCTGGTCCGCCTC")));
		l.add(new NamedSequence("seq5", dnaAlphabet.buildIndexArray("CCCAACGCGGGCGGGGAACATCCACCTAAGAGGATTCGGCGGCTGCATTCGGCCGCCGACGGCCACCGTAGGGTCACCGTTCTGCTCCGGGACCGTTCGC")));
		l.add(new NamedSequence("seq6", dnaAlphabet.buildIndexArray("CGGGTACTGGTCGAAATATCACCGCTGGGTTTCGACCTGCGAGTAAGAGGACGGCGGCTCGGAACAAACCCGCCTGGATCGACGCCCCGGCCCCGCAGTT")));
		l.add(new NamedSequence("seq7", dnaAlphabet.buildIndexArray("CCTGAAGGTTGGTATAGTCTGCTAAGACGACTCCAGCATGAGCGGTGAGGACACCAGCCAGTAGCACCGAACAACGGTTGCATCGCCGGCCGCCCCGCGG")));
		l.add(new NamedSequence("seq8", dnaAlphabet.buildIndexArray("GGTTGGAGTGACCAATAGTCCGATGCGATCCGGTGGGCGGGATCGCCCGCATGCTATCTCAGGCTCTTCGCTGGCGTGCCACTTGGCCAAAGAGGTGGTG")));
		int[] s = SequenceUtils.concatSequences(l, false);
		SuffixTree st = new SuffixTree(s, dnaAlphabet.size());
		BitArray[] annotation = st.calcSequenceOccurrenceAnnotation(SequenceUtils.appendMinusOnes(l,false));
		assertSequenceCountAnnotation(dnaAlphabet, st, l, annotation, Integer.MAX_VALUE, false);
		for (int maxDepth=0; maxDepth<10; ++maxDepth) {
			st = new SuffixTree(s, dnaAlphabet.size(), maxDepth);
			annotation = st.calcSequenceOccurrenceAnnotation(SequenceUtils.appendMinusOnes(l,false));
			assertSequenceCountAnnotation(dnaAlphabet, st, l, annotation, maxDepth, false);			
		}

		s = SequenceUtils.concatSequences(l, true);
		st = new SuffixTree(s, dnaAlphabet.size());
		annotation = st.calcSequenceOccurrenceAnnotation(SequenceUtils.appendMinusOnes(l,true));
		assertSequenceCountAnnotation(dnaAlphabet, st, l, annotation, Integer.MAX_VALUE, true);
		for (int maxDepth=0; maxDepth<10; ++maxDepth) {
			st = new SuffixTree(s, dnaAlphabet.size(), maxDepth);
			annotation = st.calcSequenceOccurrenceAnnotation(SequenceUtils.appendMinusOnes(l,true));
			assertSequenceCountAnnotation(dnaAlphabet, st, l, annotation, maxDepth, true);			
		}
	}
	
}
